package org.zjvis.datascience.spark.util;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.spark.scheduler.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

/**
 * Spark job listener that tracks per-job progress weights and persists the
 * aggregate progress of a task instance to the {@code aiworks.task_instance}
 * table over JDBC.
 *
 * @date 2021-12-23
 */
public class SparkListenerInterfaceImpl implements SparkListenerInterface {

    protected final static Logger logger = LoggerFactory.getLogger("SparkListenerInterfaceImpl");

    //jobId和jobInfo的映射
    private Map<Integer, JobInfo> jobIdToJobInfo = new HashMap();
    //剩余权重
    private Integer weight = new Integer(100);

    private Connection conn;

    private Long taskInstanceId;

    private Integer progress = 0;

    public SparkListenerInterfaceImpl(Long taskInstanceId) {
        this.taskInstanceId = taskInstanceId;
        conn = UtilTool.getConn();
    }

    @Data
    @AllArgsConstructor
    class JobInfo {
        private Integer jobId;
        private Integer weight;
        private Integer percent;
    }

    @Override
    public void onJobStart(SparkListenerJobStart jobStart) {
        int jobId = jobStart.jobId();
        synchronized (jobIdToJobInfo) {
            Integer currentWeight = weight / 5;
            weight -= weight / 5;
            jobIdToJobInfo.put(jobId, new JobInfo(jobId, currentWeight, 0));
        }
    }

    @Override
    public void onJobEnd(SparkListenerJobEnd jobEnd) {
        Integer jobId = jobEnd.jobId();
        logger.info("onJobEnd, jobId={}", jobId);
        JobInfo jobInfo = jobIdToJobInfo.get(jobId);
        jobInfo.setPercent(100);
        Integer currentProgress = 0;
        synchronized (jobIdToJobInfo) {
            Collection<JobInfo> jobInfos = jobIdToJobInfo.values();
            for (JobInfo info : jobInfos) {
                currentProgress += info.getWeight() * info.getPercent() / 100;
            }
            if (currentProgress > progress) {
                progress = currentProgress;
                updateProgress(progress);
            }
        }
    }

    @Override
    public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) {
        try {
            conn.close();
        } catch (SQLException e) {
        }
    }

    public void updateProgress(Integer progress) {
        if (conn != null) {
            String sql = String.format("update aiworks.task_instance set progress = %s where id = %s and progress != 100", progress, taskInstanceId);
            logger.info(sql);
            Statement statement = null;
            try {
                statement = conn.createStatement();
                statement.execute(sql);
                statement.close();
            } catch (SQLException e) {
                logger.error(e.getMessage());
                e.printStackTrace();
            } finally {
                if (null != statement) {
                    try {
                        statement.close();
                    } catch (SQLException e) {
                    }
                }
            }
        } else {
            logger.error("jdbc conn is null");
            System.exit(1);
        }
    }

    @Override
    public void onStageCompleted(SparkListenerStageCompleted stageCompleted) {

    }

    @Override
    public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {

    }

    @Override
    public void onTaskStart(SparkListenerTaskStart taskStart) {

    }

    @Override
    public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) {

    }

    @Override
    public void onTaskEnd(SparkListenerTaskEnd taskEnd) {

    }

    @Override
    public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) {

    }

    @Override
    public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) {

    }

    @Override
    public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) {

    }

    @Override
    public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) {

    }

    @Override
    public void onApplicationStart(SparkListenerApplicationStart applicationStart) {

    }

    @Override
    public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) {

    }

    @Override
    public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) {

    }

    @Override
    public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) {

    }

    @Override
    public void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executorBlacklisted) {

    }

    @Override
    public void onExecutorBlacklistedForStage(
            SparkListenerExecutorBlacklistedForStage executorBlacklistedForStage) {

    }

    @Override
    public void onNodeBlacklistedForStage(
            SparkListenerNodeBlacklistedForStage nodeBlacklistedForStage) {

    }

    @Override
    public void onExecutorUnblacklisted(SparkListenerExecutorUnblacklisted executorUnblacklisted) {

    }

    @Override
    public void onNodeBlacklisted(SparkListenerNodeBlacklisted nodeBlacklisted) {

    }

    @Override
    public void onNodeUnblacklisted(SparkListenerNodeUnblacklisted nodeUnblacklisted) {

    }

    @Override
    public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) {

    }

    @Override
    public void onSpeculativeTaskSubmitted(SparkListenerSpeculativeTaskSubmitted speculativeTask) {

    }

    @Override
    public void onOtherEvent(SparkListenerEvent event) {

    }
}
